{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gNlXrRnVGmb8"
      },
      "source": [
        "This colab has an [accompanying video](https://www.youtube.com/watch?v=o6JNHoGUXCo)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RXVkGtU5Hq9E"
      },
      "outputs": [],
      "source": [
        "import itertools as it\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pyspiel\n",
        "\n",
        "from open_spiel.python.algorithms import exploitability\n",
        "from open_spiel.python import policy as policy_lib\n",
        "\n",
        "np.set_printoptions(precision=3, suppress=True, floatmode='fixed')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "shtvgpHzMrcN"
      },
      "outputs": [],
      "source": [
        "game = pyspiel.load_game('tic_tac_toe')\n",
        "state = game.new_initial_state()\n",
        "\n",
        "print(state)\n",
        "while not state.is_terminal():\n",
        "  action = np.random.choice(state.legal_actions())\n",
        "  print(f'Taking action {action} {state.action_to_string(action)}')\n",
        "  state.apply_action(action)\n",
        "  print(state)\n",
        "print(f'Game over; returns {state.returns()}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VOaIgtInHyUs"
      },
      "outputs": [],
      "source": [
        "game = pyspiel.load_game('kuhn_poker')\n",
        "print(game.get_type().pretty_print())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PwNaIqi4IlfL"
      },
      "outputs": [],
      "source": [
        "policy = policy_lib.TabularPolicy(game)\n",
        "print(policy.states_per_player)\n",
        "print(policy.action_probability_array)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WJJPURUH_n64"
      },
      "outputs": [],
      "source": [
        "def print_policy(policy):\n",
        "  for state, probs in zip(it.chain(*policy.states_per_player),\n",
        "                          policy.action_probability_array):\n",
        "    print(f'{state:6}   p={probs}')\n",
        "\n",
        "print_policy(policy)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qbEt7PpRkrpq"
      },
      "outputs": [],
      "source": [
        "print(exploitability.nash_conv(game, policy))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bYkVJGQnIxx4"
      },
      "outputs": [],
      "source": [
        "def new_reach(so_far, player, action_prob):\n",
        "  \"\"\"Returns new reach probabilities.\"\"\"\n",
        "  new = np.array(so_far)\n",
        "  new[player] *= action_prob\n",
        "  return new\n",
        "\n",
        "def calc_cfr(state, reach):\n",
        "  \"\"\"Updates regrets; returns utility for all players.\"\"\"\n",
        "  if state.is_terminal():\n",
        "    return state.returns()\n",
        "  elif state.is_chance_node():\n",
        "    return sum(prob * calc_cfr(state.child(action), new_reach(reach, -1, prob))\n",
        "               for action, prob in state.chance_outcomes())\n",
        "  else:\n",
        "    # We are at a player decision point.\n",
        "    player = state.current_player()\n",
        "    index = policy.state_index(state)\n",
        "    \n",
        "    # Compute utilities after each action, updating regrets deeper in the tree.\n",
        "    utility = np.zeros((game.num_distinct_actions(), game.num_players()))\n",
        "    for action in state.legal_actions():\n",
        "      prob = curr_policy[index][action]\n",
        "      utility[action] = calc_cfr(state.child(action), new_reach(reach, player, prob))\n",
        "\n",
        "    # Compute regrets at this state.\n",
        "    cfr_prob = np.prod(reach[:player]) * np.prod(reach[player+1:])\n",
        "    value = np.einsum('ap,a-\u003ep', utility, curr_policy[index])\n",
        "    for action in state.legal_actions():\n",
        "      regrets[index][action] += cfr_prob * (utility[action][player] - value[player])\n",
        "\n",
        "    # Return the value of this state for all players.\n",
        "    return value"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ueprrRS9NbO_"
      },
      "outputs": [],
      "source": [
        "game = pyspiel.load_game('kuhn_poker')\n",
        "# game = pyspiel.load_game('turn_based_simultaneous_game(game=goofspiel(imp_info=true,num_cards=4,players=2,points_order=descending))')\n",
        "policy = policy_lib.TabularPolicy(game)\n",
        "initial_state = game.new_initial_state()\n",
        "curr_policy = policy.action_probability_array.copy()\n",
        "regrets = np.zeros_like(policy.action_probability_array)\n",
        "eval_steps = []\n",
        "eval_nash_conv = []\n",
        "for step in range(129):\n",
        "  # Compute regrets\n",
        "  calc_cfr(initial_state, np.ones(1 + game.num_players()))\n",
        "\n",
        "  # Find the new regret-matching policy\n",
        "  floored_regrets = np.maximum(regrets, 1e-16)\n",
        "  sum_floored_regrets = np.sum(floored_regrets, axis=1, keepdims=True)\n",
        "  curr_policy = floored_regrets / sum_floored_regrets\n",
        "\n",
        "  # Update the average policy\n",
        "  lr = 1 / (1 + step)\n",
        "  policy.action_probability_array *= (1 - lr)\n",
        "  policy.action_probability_array += curr_policy * lr\n",
        "\n",
        "  # Evaluate the average policy\n",
        "  if step \u0026 (step-1) == 0:\n",
        "    nc = exploitability.nash_conv(game, policy)\n",
        "    eval_steps.append(step)\n",
        "    eval_nash_conv.append(nc)\n",
        "    print(f'Nash conv after step {step} is {nc}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3p1_k0gMocEp"
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots()\n",
        "ax.set_title(\"NashConv by CFR Iteration\")\n",
        "ax.plot(eval_steps, eval_nash_conv)\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qZI7PtXFoq0l"
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots()\n",
        "ax.set_title(\"NashConv by CFR Iteration (log-log scale)\")\n",
        "ax.loglog(eval_steps, eval_nash_conv)\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "orVhRjWBsBWx"
      },
      "outputs": [],
      "source": [
        "# Display the whole policy\n",
        "print_policy(policy)\n",
        "\n",
        "# How likely are we to bet with a Jack?\n",
        "alpha = policy.action_probability_array[policy.state_lookup['0']][1]\n",
        "print(f'P(bet with Jack) = alpha = {alpha:.3}')\n",
        "\n",
        "# How likely are we to bet with a King?\n",
        "pK = policy.action_probability_array[policy.state_lookup['2']][1]\n",
        "print(f'P(bet with King) = {pK:.3}, cf {alpha * 3:.3}')\n",
        "\n",
        "# How likely are we to call with a Queen?\n",
        "pQ = policy.action_probability_array[policy.state_lookup['1pb']][1]\n",
        "print(f'P(call with Queen after checking) = {pQ:.3}, cf {alpha + 1/3:.3}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_rflT-jousiH"
      },
      "outputs": [],
      "source": [
        "def sample(actions_and_probs):\n",
        "  actions, probs = zip(*actions_and_probs)\n",
        "  return np.random.choice(actions, p=probs)\n",
        "\n",
        "def policy_as_list(policy, state):\n",
        "  return list(enumerate(policy.policy_for_key(state.information_state_string())))\n",
        "\n",
        "def env_action(state):\n",
        "  if state.is_chance_node():\n",
        "    p = state.chance_outcomes()\n",
        "  else:\n",
        "    p = policy_as_list(fixed_policy, state)\n",
        "  return sample(p)\n",
        "\n",
        "def softmax(x):\n",
        "  x = np.exp(x - np.max(x, axis=-1, keepdims=True))\n",
        "  return x / np.sum(x, axis=-1, keepdims=True)\n",
        "\n",
        "def generate_trajectory(state, player):\n",
        "  trajectory = []\n",
        "  while not state.is_terminal():\n",
        "    if state.current_player() == player:\n",
        "      action = sample(policy_as_list(rl_policy, state))\n",
        "      trajectory.append((rl_policy.state_index(state), action))\n",
        "    else:\n",
        "      action = env_action(state)\n",
        "    state.apply_action(action)\n",
        "  return trajectory, state.returns()[player]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y9Sn8g147_90"
      },
      "outputs": [],
      "source": [
        "fixed_policy = policy_lib.TabularPolicy(game)\n",
        "rl_policy = policy_lib.TabularPolicy(game)\n",
        "for _ in range(5):\n",
        "  print(generate_trajectory(game.new_initial_state(), player=0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R_L2KQ5u9Nc3"
      },
      "outputs": [],
      "source": [
        "# Run REINFORCE\n",
        "N = 10000\n",
        "lr = 0.01\n",
        "for step in range(N):\n",
        "  for player in (0, 1):\n",
        "    trajectory, reward = generate_trajectory(game.new_initial_state(), player)\n",
        "    for s, a in trajectory:\n",
        "      logits = np.log(rl_policy.action_probability_array[s])\n",
        "      logits[a] += lr * reward\n",
        "      rl_policy.action_probability_array[s] = softmax(logits)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8-25i6xFRUqC"
      },
      "outputs": [],
      "source": [
        "# Evaluate the policy\n",
        "def evaluate(state, rl_policy, player):\n",
        "  if state.is_terminal():\n",
        "    return state.returns()[player]\n",
        "  elif state.current_player() == player:\n",
        "    ap = policy_as_list(rl_policy, state)\n",
        "  elif state.is_chance_node():\n",
        "    ap = state.chance_outcomes()\n",
        "  else:\n",
        "    ap = policy_as_list(fixed_policy, state)\n",
        "  return sum(p * evaluate(state.child(a), rl_policy, player) for a, p in ap)\n",
        "\n",
        "def eval(rl_policy):\n",
        "  return (evaluate(game.new_initial_state(), rl_policy, player=0)\n",
        "        + evaluate(game.new_initial_state(), rl_policy, player=1))\n",
        "\n",
        "print_policy(rl_policy)\n",
        "eval(rl_policy)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IZX8G1sGLSBT"
      },
      "outputs": [],
      "source": [
        "# Evaluate the greedy policy\n",
        "greedy_policy = policy_lib.TabularPolicy(game)\n",
        "greedy_policy.action_probability_array = (np.eye(game.num_distinct_actions())\n",
        "              [np.argmax(rl_policy.action_probability_array, axis=-1)])\n",
        "\n",
        "print_policy(greedy_policy)\n",
        "eval(greedy_policy)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "CFR_and_REINFORCE.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
